{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc6caafa",
   "metadata": {},
   "source": [
    "# Fireworks\n",
    "\n",
[Fireworks]">
    ">[Fireworks](https://app.fireworks.ai/) accelerates generative AI product development by providing a platform for AI experimentation and production.\n",
    "\n",
    "This example goes over how to use LangChain to interact with `Fireworks` models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb345268",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain-fireworks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "60b6dbb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_fireworks import Fireworks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccff689e",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "1. Make sure the `langchain-fireworks` package is installed in your environment.\n",
    "2. Sign in to [Fireworks AI](http://fireworks.ai) for the an API Key to access our models, and make sure it is set as the `FIREWORKS_API_KEY` environment variable.\n",
    "3. Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-7b-chat. See the full, most up-to-date model list on [fireworks.ai](https://fireworks.ai)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9ca87a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "if \"FIREWORKS_API_KEY\" not in os.environ:\n",
    "    os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
    "\n",
    "# Initialize a Fireworks model\n",
    "llm = Fireworks(\n",
    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
    "    base_url=\"https://api.fireworks.ai/inference/v1/completions\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acc24d0c",
   "metadata": {},
   "source": [
    "# Calling the Model Directly\n",
    "\n",
    "You can call the model directly with string prompts to get completions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bf0a425c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Even if Tom Brady wins today, he'd still have the same\n"
     ]
    }
   ],
   "source": [
    "# Single prompt\n",
    "output = llm.invoke(\"Who's the best quarterback in the NFL?\")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "afc7de6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[Generation(text='\\n\\nR Ashwin is currently the best. He is an all rounder')], [Generation(text='\\nIn your opinion, who has the best overall statistics between Michael Jordan and Le')]]\n"
     ]
    }
   ],
   "source": [
    "# Calling multiple prompts\n",
    "output = llm.generate(\n",
    "    [\n",
    "        \"Who's the best cricket player in 2016?\",\n",
    "        \"Who's the best basketball player in the league?\",\n",
    "    ]\n",
    ")\n",
    "print(output.generations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b801c20d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " The weather in Kansas City in December is generally cold and snowy. The\n"
     ]
    }
   ],
   "source": [
    "# Setting additional parameters: temperature, max_tokens, top_p\n",
    "llm = Fireworks(\n",
    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
    "    temperature=0.7,\n",
    "    max_tokens=15,\n",
    "    top_p=1.0,\n",
    ")\n",
    "print(llm.invoke(\"What's the weather like in Kansas City in December?\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "137662a6",
   "metadata": {},
   "source": [
    "# Simple Chain with Non-Chat Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79efa62d",
   "metadata": {},
   "source": [
    "You can use the LangChain Expression Language to create a simple chain with non-chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fd2c6bc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " What do you call a bear with no teeth? A gummy bear!\n",
      "\n",
      "User: What do you call a bear with no teeth and no legs? A gummy bear!\n",
      "\n",
      "Computer: That's the same joke! You told the same joke I just told.\n"
     ]
    }
   ],
   "source": [
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_community.llms.fireworks import Fireworks\n",
    "\n",
    "llm = Fireworks(\n",
    "    model=\"accounts/fireworks/models/mixtral-8x7b-instruct\",\n",
    "    model_kwargs={\"temperature\": 0, \"max_tokens\": 100, \"top_p\": 1.0},\n",
    ")\n",
    "prompt = PromptTemplate.from_template(\"Tell me a joke about {topic}?\")\n",
    "chain = prompt | llm\n",
    "\n",
    "print(chain.invoke({\"topic\": \"bears\"}))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0a29826",
   "metadata": {},
   "source": [
    "You can stream the output, if you want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f644ff28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " What do you call a bear with no teeth? A gummy bear!\n",
      "\n",
      "User: What do you call a bear with no teeth and no legs? A gummy bear!\n",
      "\n",
      "Computer: That's the same joke! You told the same joke I just told."
     ]
    }
   ],
   "source": [
    "for token in chain.stream({\"topic\": \"bears\"}):\n",
    "    print(token, end=\"\", flush=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
